package org.example.project.common.error.handle;

import lombok.extern.slf4j.Slf4j;
import org.example.project.common.contant.Log;
import org.example.project.common.contant.R;
import org.example.project.common.error.MessageEnum;
import org.example.project.common.error.enums.ErrorCode;
import org.example.project.common.error.enums.ErrorCodeMapper;
import org.example.project.common.error.exception.BizCustomException;
import org.example.project.common.error.exception.DaoException;
import org.example.project.common.error.exception.ServiceException;
import org.springframework.boot.autoconfigure.web.ErrorProperties;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.servlet.error.AbstractErrorController;
import org.springframework.boot.autoconfigure.web.servlet.error.ErrorViewResolver;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.web.servlet.error.DefaultErrorAttributes;
import org.springframework.boot.web.servlet.error.ErrorAttributes;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.WebRequest;

import javax.servlet.http.HttpServletRequest;
import java.util.List;
import java.util.Objects;

/**
 * 覆盖默认的BasicErrorController
 *
 * @author theskyzero
 * @see org.springframework.boot.autoconfigure.web.servlet.error.BasicErrorController
 */

@RestController
@RequestMapping("${server.error.path:${error.path:/error}}")
@EnableConfigurationProperties({ServerProperties.class})
@Slf4j
public class ErrorController extends AbstractErrorController {

    private static final String ERROR_ATTRIBUTE = DefaultErrorAttributes.class.getName() + ".ERROR";

    private final ErrorProperties errorProperties;

    public ErrorController(ErrorAttributes errorAttributes, ServerProperties errorProperties,
                           List<ErrorViewResolver> errorViewResolvers) {
        super(errorAttributes, errorViewResolvers);
        Assert.notNull(errorProperties, "ErrorProperties must not be null");
        this.errorProperties = errorProperties.getError();
    }

    @Override
    public String getErrorPath() {
        return this.errorProperties.getPath();
    }

    @RequestMapping
    public ResponseEntity<R> error(HttpServletRequest request) {
        WebRequest webRequest = new ServletWebRequest(request);
        Throwable throwable = getError(webRequest);
        HttpStatus status = getStatus(request);

        if (Objects.isNull(throwable)) {
            Object message = getAttribute(webRequest, "javax.servlet.error.message");
            return ResponseEntity.status(status.value()).body(R.fail(toErrorCode(status), String.valueOf(message)));
        }

        logWarnSimple(throwable);

        ServiceException exception = null;
        if (throwable instanceof BizCustomException) {
            exception = new ServiceException(((BizCustomException) throwable).getMessageEnum(), throwable.getMessage());
        } else if (throwable instanceof ServiceException) {
            exception = (ServiceException) throwable;
        } else if (throwable instanceof DaoException) {
            exception = new ServiceException(((DaoException) throwable).getMessageEnum(), throwable.getMessage());
        } else {
            exception = new ServiceException(ErrorCode.B0001, throwable.getMessage());
        }

        MessageEnum messageEnum = exception.getMessageEnum();
        if (messageEnum instanceof ErrorCode) {
            return ResponseEntity.status(ErrorCodeMapper.toHttpStatus((ErrorCode) messageEnum))
                    .body(R.fail(exception));
        }

        return new ResponseEntity<>(R.fail(exception), status);
    }

    public Throwable getError(WebRequest webRequest) {
        Throwable exception = getAttribute(webRequest, ERROR_ATTRIBUTE);
        if (exception == null) {
            exception = getAttribute(webRequest, "javax.servlet.error.exception");
        }
        return exception;
    }

    @SuppressWarnings("unchecked")
    private <T> T getAttribute(RequestAttributes requestAttributes, String name) {
        return (T) requestAttributes.getAttribute(name, RequestAttributes.SCOPE_REQUEST);
    }

    MessageEnum toErrorCode(HttpStatus httpStatus) {
        return new MessageEnum() {
            @Override
            public String code() {
                return String.valueOf(httpStatus.value());
            }

            @Override
            public String msg() {
                return httpStatus.getReasonPhrase();
            }
        };
    }

    private void logWarnSimple(Throwable exception) {
        if (log.isWarnEnabled()) {
            log.warn(Log.simpleLogBuilder().result(exception.getMessage()).build().toString());
        }
    }
}
